import json
import logging
import re
from collections.abc import Mapping
from typing import List, Dict


def extract_code_block(text):
    """
    Returns the body of the first triple backtick code block in a text.

    The opening fence may carry a language hint made of word characters and must
    be followed directly by a newline; the body runs up to the next triple backtick.

    Args:
        text (str): The text to search for a fenced code block.

    Returns:
        str: The body of the first code block, or the text unchanged if it has none.
    """
    # DOTALL lets the body span several lines; the lazy quantifier stops at the
    # first closing fence, so only the first block is captured.
    first_block = re.search(r"```(?:\w+)?\n(.*?)```", text, flags=re.DOTALL)

    if first_block is None:
        return text
    return first_block.group(1)


def standardize_tool_call(tool_call: dict) -> dict | None:
    """
    Standardizes the format of tool calls according to the format expected by OpenAI.

    Args:
        tool_call (dict): The tool call to validate.

    Returns:
        dict | None: Standardized tool call if valid, None otherwise.
    """
    # Ensure the tool call has a "name"
    standardized_tool_call = {}
    if "name" in tool_call:
        standardized_tool_call["name"] = tool_call["name"]
    else:
        logging.warning("Tool call does not have a 'name' field.")
        return None

    # Ensure the tool call has "arguments"
    if "arguments" in tool_call:
        standardized_tool_call["arguments"] = tool_call["arguments"]
    elif "parameters" in tool_call:
        standardized_tool_call["arguments"] = tool_call["parameters"]
    else:
        logging.warning("Tool call does not have a 'arguments' or 'parameters' field.")
        return None

    return standardized_tool_call


def _standardize_tool_calls(candidates: list) -> List[Dict] | None:
    """
    Standardizes every candidate tool call, failing as a whole if any is invalid.

    Args:
        candidates (list): Parsed JSON values that are expected to be tool calls.

    Returns:
        List[Dict] | None: The standardized tool calls in input order, or None as
            soon as one candidate is not a JSON object or cannot be standardized.
    """
    standardized = []
    for candidate in candidates:
        # A JSON array may hold scalars or nested arrays; only objects can be tool calls
        call = standardize_tool_call(candidate) if isinstance(candidate, dict) else None
        if call is None:
            return None
        standardized.append(call)
    return standardized


def _parse_tool_call_payload(content: str) -> List[Dict] | None:
    """
    Parses the text captured inside one tool calling marker into tool calls.

    Args:
        content (str): The text captured between the tool calling identifiers.

    Returns:
        List[Dict] | None: The standardized tool calls in payload order, or None if
            the payload is not valid JSON, holds no tool call, or holds an invalid one.
    """
    try:
        payload = json.loads(content)
    except json.JSONDecodeError:
        # The [TOOL_CALLS] pattern captures what sits inside the outer brackets, so
        # several calls arrive as "{...}, {...}". Retry with the brackets restored.
        try:
            payload = json.loads(f"[{content}]")
        except json.JSONDecodeError:
            logging.warning("Could not parse tool call as JSON.")
            return None

    if isinstance(payload, dict):
        payload = [payload]
    if not isinstance(payload, list) or not payload:
        logging.warning("Tool call is not a JSON object or a list of JSON objects.")
        return None

    return _standardize_tool_calls(payload)


def extract_tool_calls(
    text: str, added_tokens_decoder: Dict
) -> tuple[List[Dict], str]:
    """
    Extracts tool calls from generated text based on tool calling identifiers.

    Two strategies are tried:

    1. If the whole text (or its first triple backtick code block) is valid JSON,
       it is treated as the tool call payload. When it is an object or a list of
       objects that all standardize, those calls are returned with an empty text.
       Any other JSON value, or any invalid call, yields no calls and the text
       exactly as given.
    2. Otherwise, blocks delimited by the model's tool calling identifiers are
       parsed one by one. Blocks that parse into valid tool calls are removed from
       the text; blocks that do not are left in place.

    Args:
        text (str): The text output generated by the model.
        added_tokens_decoder (Dict): The tokenizer.added_tokens_decoder mapping.
            Only its values are read, and each value must expose a `content`
            attribute holding the token string.

    Returns:
        tuple[List[Dict], str]: A tuple containing:
            - List[Dict]: The standardized tool calls ("name" and "arguments"
              keys), in the order they appear in the text
            - str: The original text with tool calls removed
    """
    matches = []
    special_tokens = [token.content for token in added_tokens_decoder.values()]

    # Pattern 1: <tool_call>...</tool_call> block
    # Sample model that uses this pattern: Qwen3-8B
    if "<tool_call>" in special_tokens and "</tool_call>" in special_tokens:
        tool_call_pattern = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)
        matches = list(tool_call_pattern.finditer(text))

    # Pattern 2: [TOOL_CALLS] [ {...} ] block
    # Sample model that uses this pattern: Mistral-7B-Instruct-v0.3
    # The lookahead only accepts a closing bracket that is followed by optional
    # whitespace and "<", by "eos" or "/eos", or by the end of the text, so a "]"
    # inside the arguments does not end the block early.
    elif "[TOOL_CALLS]" in special_tokens:
        tool_call_pattern = re.compile(
            r"\[TOOL_CALLS\]\s*\[(.*?)\](?=\s*<|/?eos|$)", re.DOTALL
        )
        matches = list(tool_call_pattern.finditer(text))

    else:
        logging.warning(
            "Tool calling identifiers were not found for the current model."
        )

    # Some models don't use any tool calling identifiers.
    # Instead, tool calls are identified by only generating JSON content.
    # Sample model that uses this pattern: Llama-3.1-8B-Instruct
    try:
        # Remove the json for a code block if needed
        whole_payload = json.loads(extract_code_block(text))
    except json.JSONDecodeError:
        # Not pure JSON: fall through to the identifier-based extraction below
        pass
    else:
        if isinstance(whole_payload, dict):
            whole_payload = [whole_payload]

        # Scalars (numbers, strings, booleans, null) are valid JSON but can never
        # be tool calls; hand the text back untouched instead of iterating them.
        if not isinstance(whole_payload, list):
            return [], text

        # Return the tool calls only if all of them are valid
        json_tool_calls = _standardize_tool_calls(whole_payload)
        if json_tool_calls is None:
            return [], text
        return json_tool_calls, ""

    # Process matches in reverse to avoid position shifting
    cleaned_text = text
    calls_per_match = []
    for match in reversed(matches):
        match_calls = _parse_tool_call_payload(match.group(1).strip())
        if match_calls is None:
            continue

        calls_per_match.append(match_calls)

        # Remove the matched tool call from the text
        cleaned_text = cleaned_text[: match.start()] + cleaned_text[match.end() :]

    # Matches were visited last-to-first; flip them back so that the returned
    # calls follow the order in which the model emitted them.
    extracted_tool_calls = [
        call for match_calls in reversed(calls_per_match) for call in match_calls
    ]

    return extracted_tool_calls, cleaned_text.strip()
